{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Explaining a Question Answering Transformers Model\n", "\n", "Here we demonstrate how to explain the output of a question answering model that predicts which range of the context text contains the answer to a given question." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [
"import numpy as np\n",
"import torch\n",
"import transformers\n",
"\n",
"import shap\n",
"\n",
"# load the model\n",
"pmodel = transformers.pipeline(\"question-answering\")\n",
"tokenized_qs = None  # variable to store the tokenized data\n",
"\n",
"\n",
"# define two predictions, one that outputs the logits for the range start,\n",
"# and the other for the range end\n",
"def f(questions, tokenized_qs, start):\n",
"    \"\"\"Return the start or end logits for each masked token-id sequence.\n",
"\n",
"    questions: iterable of id sequences; each one supplies the replacement\n",
"        input ids for every position except the first separator token.\n",
"    tokenized_qs: tokenizer output for the unmasked question/context pair.\n",
"        It is only read here, never modified.\n",
"    start: if truthy return start_logits, otherwise end_logits.\n",
"\n",
"    Returns a list with one flat numpy array of logits per entry in questions.\n",
"    \"\"\"\n",
"    # The separator position does not depend on the sample, so locate it once\n",
"    # in the unmasked ids (this code assumes that there is only one sentence in data).\n",
"    base_ids = list(tokenized_qs[\"input_ids\"])\n",
"    idx = np.argwhere(np.array(base_ids) == pmodel.tokenizer.sep_token_id)[0, 0]\n",
"    outs = []\n",
"    for q in questions:\n",
"        # Build a fresh id list per sample. A dict-style .copy() is shallow, so\n",
"        # slice-assigning into its \"input_ids\" would overwrite the shared\n",
"        # tokenization that f_start and f_end both read.\n",
"        ids = list(base_ids)\n",
"        ids[:idx] = q[:idx]\n",
"        ids[idx + 1 :] = q[idx + 1 :]\n",
"        d = {k: tokenized_qs[k] for k in tokenized_qs}\n",
"        d[\"input_ids\"] = ids\n",
"        # Inference only: skip autograd bookkeeping and call the module itself\n",
"        # rather than .forward so registered hooks still run.\n",
"        with torch.no_grad():\n",
"            out = pmodel.model(**{k: torch.tensor(d[k]).reshape(1, -1) for k in d})\n",
"        logits = out.start_logits if start else out.end_logits\n",
"        outs.append(logits.reshape(-1).numpy())\n",
"    return outs\n",
"\n",
"\n",
"def tokenize_data(data):\n",
"    \"\"\"Tokenize the question/context pair(s) in data and return the last result.\n",
"\n",
"    Each entry must be a string of the form question[SEP]context.\n",
"    This code assumes that there is only one sentence in data; when several\n",
"    are given, only the tokenization of the last one is returned.\n",
"    Raises ValueError when data is empty.\n",
"    \"\"\"\n",
"    tokenized_data = None\n",
"    for q in data:\n",
"        question, context = q.split(\"[SEP]\")\n",
"        tokenized_data = pmodel.tokenizer(question, context)\n",
"    if tokenized_data is None:\n",
"        raise ValueError(\"tokenize_data needs one 'question[SEP]context' string\")\n",
"    return tokenized_data\n",
"\n",
"\n",
"def f_start(questions):\n",
"    # reads the module-level tokenized_qs, which must be set before explaining\n",
"    return f(questions, tokenized_qs, True)\n",
"\n",
"\n",
"def f_end(questions):\n",
"    # reads the module-level tokenized_qs, which must be set before explaining\n",
"    return f(questions, tokenized_qs, False)\n",
"\n",
"\n",
"# attach a dynamic output_names property to the models so we can plot the tokens at each output position\n",
"def out_names(inputs):\n", " question, context = inputs.split(\"[SEP]\")\n", " d = 
pmodel.tokenizer(question, context)\n", " return [pmodel.tokenizer.decode([id]) for id in d[\"input_ids\"]]\n", "\n", "\n", "f_start.output_names = out_names\n", "f_end.output_names = out_names" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Explain the starting positions\n", "\n", "Here we explain the starting range predictions of the model. Note that because the model output depends on the length of the model input, it is important that we pass the model's native tokenizer for masking, so that when we hide portions of the text we can retain the same number of tokens and hence the same meaning for each output position." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "4b4f99a75cbc4f3984b0bfb65fbaaff9", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/498 [00:00\n", "
\n", "
[0]
\n", "
\n", "
\n", "\n", "
outputs
\n", "
[CLS]
\n", "
What
\n", "
is
\n", "
on
\n", "
the
\n", "
table
\n", "
?
\n", "
[SEP]
\n", "
When
\n", "
I
\n", "
got
\n", "
home
\n", "
today
\n", "
I
\n", "
saw
\n", "
my
\n", "
cat
\n", "
on
\n", "
the
\n", "
table
\n", "
,
\n", "
and
\n", "
my
\n", "
frog
\n", "
on
\n", "
the
\n", "
floor
\n", "
.
\n", "
[SEP]


-1-4-725-1.26347-1.26347base value-3.97193-3.97193f[CLS](inputs)0.498 What 0.214 on 0.032 home 0.026 got 0.022 the 0.013 today -0.937 ? -0.459 is -0.374 table -0.254 cat -0.25 on -0.233 on the floor -0.182 and -0.137 saw -0.126 , -0.108 . -0.094 my -0.086 frog -0.078 my -0.059 When -0.05 the -0.049 I -0.037 table -0.0 I
inputs
0.0
0.498
What
-0.459
is
0.214
on
0.022
the
-0.374
table
-0.937
?
0.0
[SEP]
-0.059
When
-0.049
I
0.026
got
0.032
home
0.013
today
-0.0
I
-0.137
saw
-0.078
my
-0.254
cat
-0.25
on
-0.05
the
-0.037
table
-0.126
,
-0.182
and
-0.094
my
-0.086
frog
-0.233 / 3
on the floor
-0.108
.
0.0